/**
 * @author LKQ
 * @date 2022/2/22 20:14
 * @description Flattens a binary tree in place into a right-leaning chain in pre-order.
 */
public class Solution {
    /**
     * Small demo driver: builds a nine-node tree and flattens it in place.
     *
     * <p>Assuming the {@code TreeNode(val, left, right)} argument order (TODO confirm against the
     * {@code TreeNode} declaration), the tree built here is
     * <pre>
     *           1
     *         /   \
     *        2     7
     *       / \   / \
     *      3   6 8   9
     *     / \
     *    4   5
     * </pre>
     * and after {@link #flatten(TreeNode)} it becomes the right-leaning chain
     * 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> 9.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Solution solution = new Solution();
        TreeNode t3 = new TreeNode(3, new TreeNode(4), new TreeNode(5));
        TreeNode t2 = new TreeNode(2, t3, new TreeNode(6));
        TreeNode t7 = new TreeNode(7, new TreeNode(8), new TreeNode(9));
        TreeNode t1 = new TreeNode(1, t2, t7);
        solution.flatten(t1);
    }

    /**
     * Flattens the tree rooted at {@code root} in place into a right-leaning chain that follows
     * the tree's pre-order (node, left subtree, right subtree). On return every node's
     * {@code left} is {@code null} and its {@code right} points to its pre-order successor.
     *
     * <p>Works iteratively, walking down the chain as it is built. For each node that still has
     * a left subtree, the node's current right subtree is attached below the rightmost node of
     * the left subtree. The left subtree is then moved into the right slot. This avoids the
     * recursive version's repeated walks to the tail of each flattened left subtree, which cost
     * O(n^2) on a left-skewed tree. It also avoids that version's O(height) call-stack depth.
     * Extra space is O(1).
     *
     * @param root root of the tree to flatten; {@code null} is a no-op
     */
    public void flatten(TreeNode root) {
        TreeNode current = root;
        while (current != null) {
            if (current.left != null) {
                // Find the rightmost node of the left subtree by following right links only.
                // The old right subtree comes after the whole left subtree in pre-order, so
                // hang it off that node. Anything still under that node's left link is spliced
                // in ahead of it when the outer loop reaches that node, so pre-order holds.
                TreeNode tail = current.left;
                while (tail.right != null) {
                    tail = tail.right;
                }
                tail.right = current.right;
                current.right = current.left;
                current.left = null;
            }
            current = current.right;
        }
    }
}
